
import torch
import torchvision
from datasets import load_dataset

class AdjustedDataset(torch.utils.data.Dataset):
    '''Adapter exposing a dataset of mapping-style records as (image, label) pairs.

    The wrapped dataset must support ``len()`` and integer indexing, and each
    item must be subscriptable with the keys ``'image'`` and ``'label'``.
    '''

    def __init__(self, original_dataset):
        # Keep a reference only; the wrapped dataset is neither copied nor modified.
        self.dataset = original_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # Any other keys in the record are dropped.
        return record['image'], record['label']

def get_transformations(args):
    '''Build the torchvision transform pipelines.

    ``args.transformation`` selects the steps of the configurable 'train'
    pipeline through single-letter codes: 'C' (random crop to
    ``args.image_size`` with padding 4), 'H' (random horizontal flip),
    'J' (colour jitter applied with p=0.8), 'G' (random grayscale, p=0.2)
    and 'N' (normalisation). Selected steps always run in that fixed order,
    whatever order the codes appear in ``args.transformation``.

    Returns a dict with the keys 'train', 'train_standard' (crop + flip +
    normalise), 'test' (normalise only) and 'to_tensor' (``ToTensor`` only).
    '''
    T = torchvision.transforms

    random_crop = T.RandomCrop(args.image_size, padding=4)
    flip = T.RandomHorizontalFlip()
    color_jitter = T.RandomApply([T.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8)
    random_gray = T.RandomGrayscale(p=0.2)
    # Per-channel mean / std -- presumably CIFAR-10 statistics; TODO confirm
    # they are also intended for the other datasets.
    channel_norm = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Code -> transform, listed in the order the steps are composed.
    available = (
        ('C', random_crop),
        ('H', flip),
        ('J', color_jitter),
        ('G', random_gray),
        ('N', channel_norm),
    )
    selected = [step for code, step in available if code in args.transformation]

    # NOTE(review): only 'to_tensor' contains ToTensor, so the other pipelines
    # presumably receive inputs that are already tensors -- verify with callers.
    return {
        'train': T.Compose(selected),
        'train_standard': T.Compose([random_crop, flip, channel_norm]),
        'test': T.Compose([channel_norm]),
        'to_tensor': T.Compose([T.ToTensor()]),
    }
    
def get_data(args, transform):
    '''Load the requested dataset and split it into the subsets used downstream.

    Args:
        args: namespace providing ``dataset`` ('CIFAR10', 'CIFAR100' or
            'TinyImageNet'), ``forget_data_ratio`` (fraction of the training
            set assigned to the forget split) and, for the CIFAR datasets,
            ``data_path`` (download / cache root).
        transform: callable applied to every image. It is passed to the
            torchvision dataset for CIFAR, and called on each RGB-converted
            image for TinyImageNet.

    Returns:
        dict with the keys 'train', 'retain', 'forget', 'test', 'val'
        (datasets) and 'num_classes' (int). 'test' / 'val' are an 80% / 20%
        random split of the held-out data; 'retain' / 'forget' are a random
        split of the training data.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    '''
    if args.dataset == 'CIFAR10':
        num_classes=10
        train_set = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform)
        held_out = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform)
    elif args.dataset == 'CIFAR100':
        num_classes=100
        train_set = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform)
        held_out = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform)
    elif args.dataset == 'TinyImageNet':
        orig_dataset = load_dataset("Maysee/tiny-imagenet")
        num_classes = orig_dataset['train'].features['label'].num_classes # 200

        def preprocess_images(example):
            # Some images are not RGB; convert them so every sample has the same mode.
            image = example['image'].convert("RGB") if example['image'].mode !="RGB" else example['image']
            example['image'] = transform(image)
            return example

        # NOTE(review): the transform runs inside ``map``; if that is evaluated
        # once up front, random augmentations would be fixed per sample rather
        # than re-drawn each epoch -- confirm this is intended.
        dataset = orig_dataset.map(preprocess_images, batched=False)
        dataset.set_format(type='torch', columns=['image', 'label'])

        train_set = AdjustedDataset(dataset['train'])
        held_out = AdjustedDataset(dataset['valid'])
    else:
        # Report the offending dataset name (previously this read args.model_type).
        raise ValueError(f"Dataset {args.dataset} not defined")

    # Split held out data into test and validation set
    test_size = int(len(held_out) * 0.8) # 80% of held-out data -> test set
    val_size = len(held_out) - test_size # remaining 20% -> validation set
    test_set, val_set = torch.utils.data.random_split(held_out, [test_size, val_size])

    forget_size = int(len(train_set) * args.forget_data_ratio) # forget data size
    retain_size = len(train_set) - forget_size # retain data size

    # Forget and retain index split
    retain_set, forget_set = torch.utils.data.random_split(train_set, [retain_size, forget_size])

    print("Len of train set:", len(train_set))
    print("Len of test set:", len(test_set))
    print("Len of val set:", len(val_set))
    print("Len of retain set:", len(retain_set))
    print("Len of forget set:", len(forget_set))

    datasets = {
        'train': train_set,
        'retain': retain_set,
        'forget': forget_set,
        'test': test_set,
        'val': val_set,
        'num_classes': num_classes,
    }

    return datasets


def get_dataloaders(args, datasets):
    '''Wrap each dataset split in a ``torch.utils.data.DataLoader``.

    ``datasets`` must contain the keys 'train', 'retain', 'forget', 'test'
    and 'val'. Every loader uses ``args.batch_size`` and ``args.num_workers``;
    the train / retain / forget loaders shuffle, the test / val loaders do not.

    Returns a dict mapping each split name to its DataLoader.
    '''
    # (split name, shuffle?) -- order fixes both construction and key order.
    split_shuffle = (
        ('train', True),
        ('retain', True),
        ('forget', True),
        ('test', False),
        ('val', False),
    )

    return {
        split: torch.utils.data.DataLoader(
            datasets[split],
            batch_size=args.batch_size,
            shuffle=reshuffle,
            num_workers=args.num_workers,
        )
        for split, reshuffle in split_shuffle
    }
